Skip to content

compute loss only if training and update token metric naming#3293

Merged
NanoCode012 merged 15 commits into
axolotl-ai-cloud:mainfrom
ved1beta:training_loss
Dec 25, 2025
Merged

compute loss only if training and update token metric naming#3293
NanoCode012 merged 15 commits into
axolotl-ai-cloud:mainfrom
ved1beta:training_loss

Conversation

@ved1beta

@ved1beta ved1beta commented Dec 2, 2025

Copy link
Copy Markdown
Member

fixes #3291
image

Summary by CodeRabbit

  • Bug Fixes
    • Fixed token-per-second tracking to only occur when the model is actively training. Previously, tracking could run regardless of training state.

✏️ Tip: You can customize this high-level summary in your review settings.

@coderabbitai

coderabbitai Bot commented Dec 2, 2025

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

The compute_loss method in the base trainer now conditionally enables token-per-second tracking based on model training state. Token-per-second metrics are now computed only when model.training is True, instead of whenever the include_tkps flag was set.

Changes

Cohort / File(s) Summary
Token-per-second tracking condition
src/axolotl/core/trainers/base.py
Modified compute_loss to enable tokens-per-second counting only during training mode (model.training is True), adding a conditional guard around previously unconditional token tracking logic

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~15–20 minutes

  • Extra attention areas:
    • Verify the impact on metrics reporting during evaluation and inference phases
    • Confirm that suppressing token-per-second tracking during non-training modes doesn't break logging or monitoring pipelines
    • Check for any side effects on the loss computation flow when model.training is False

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title check ❓ Inconclusive The title partially addresses the changeset: it mentions computing loss only if training, which aligns with the main change (enabling tokens-per-second tracking only in training mode), but also mentions 'update token metric naming' which is not reflected in the provided summary of changes. Clarify whether 'update token metric naming' is part of this PR or should be removed from the title. If it is included, provide details on which metrics were renamed.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/core/trainers/base.py (1)

351-369: Gating tkps tracking on model.training aligns with intent; consider edge‑cases and minor cleanup

Conditioning the token accounting on model.training achieves the goal of only tracking tokens during actual training steps and avoids polluting metrics during eval/inference. This should be compatible with the standard Trainer.train() / Trainer.evaluate() flow where the trainer toggles model.train() / model.eval().

Two minor points to keep in mind:

  • If you have any custom code paths that call compute_loss for “training-like” work while leaving model.training == False (e.g., manual scoring runs), tkps will now be skipped there; worth confirming that you don’t rely on tkps in those flows.
  • Optional micro‑refactor: you can avoid recomputing the mask sum when updating self.state.num_tokens by reusing num_tokens (or its .cpu() copy), e.g.:
-        if self.args.include_tkps and model.training:
-            inputs_key = "labels" if "labels" in inputs else "input_ids"
-            num_tokens = (inputs[inputs_key] != -100).sum()
+        if self.args.include_tkps and model.training:
+            inputs_key = "labels" if "labels" in inputs else "input_ids"
+            num_tokens = (inputs[inputs_key] != -100).sum()
+            num_tokens_cpu = num_tokens.cpu()
@@
-            if hasattr(self.state, "num_tokens"):
-                self.state.num_tokens = (
-                    self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
-                )
-            else:
-                self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
+            if hasattr(self.state, "num_tokens"):
+                self.state.num_tokens = self.state.num_tokens + num_tokens_cpu
+            else:
+                self.state.num_tokens = num_tokens_cpu
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6ddcdd and 2a558f6.

📒 Files selected for processing (1)
  • src/axolotl/core/trainers/base.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.9.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.9.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)

@codecov

codecov Bot commented Dec 2, 2025

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 21.27660% with 37 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/utils/callbacks/tokens_per_second.py 27.58% 21 Missing ⚠️
src/axolotl/core/trainers/base.py 11.11% 16 Missing ⚠️

📢 Thoughts on this report? Let us know!

@xzuyn

xzuyn commented Dec 2, 2025

Copy link
Copy Markdown
Contributor

Solves the eval increase part, but the reset on resume still happens.

Eval every 10 steps.
image

@ved1beta

ved1beta commented Dec 3, 2025

Copy link
Copy Markdown
Member Author

ahh i thought it was just the recomputing here you goo ,
image

tokens_state = json.load(f)
state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0))
state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0))
LOG.info(f"Restored total_tokens: {state.total_tokens}")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this how we should store total_tokens? Are we able to inject it into the trainer state file that gets created in checkpoints?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ummm TrainerState is a dataclass, it only serializes defined fields

@xzuyn

xzuyn commented Dec 6, 2025

Copy link
Copy Markdown
Contributor

Resuming works too. [2025-12-06 13:23:07,478] [INFO] [axolotl.utils.callbacks.tokens_per_second] Restored total_tokens: 544092
image

@NanoCode012 NanoCode012 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in chat, let's refactor to use tokens/total and tokens/trainable. I believe self.num_tokens is redundant?

@ved1beta

ved1beta commented Dec 9, 2025

Copy link
Copy Markdown
Member Author

train eval working fine ,also checkpoints
image

@ved1beta ved1beta requested a review from NanoCode012 December 11, 2025 11:49
Comment thread src/axolotl/core/trainers/base.py Outdated
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
)
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_step"] = trainable_tokens.detach().cpu()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be removed as is unused (not logged and too similar to others)

Comment thread src/axolotl/core/trainers/base.py Outdated
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
logs["total_tokens"] = int(self.state.total_tokens.item())
logs["tokens/total"] = int(self.state.tokens["total"].item())

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we missed log tokens/trainable

Comment on lines +105 to +109
if tokens and "total" in tokens:
logs["tokens/total"] = tokens["total"].item()

if tokens and "trainable" in tokens:
logs["tokens/trainable"] = tokens["trainable"].item()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this duplicate log of base.py L651-652?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yess. redundant ?

@ved1beta ved1beta requested a review from NanoCode012 December 19, 2025 09:38
@NanoCode012 NanoCode012 changed the title compute loss only if training compute loss only if training and update token metric naming Dec 23, 2025
@winglian

Copy link
Copy Markdown
Collaborator

can we add a CI similar to TestResumeLlama that on resume the total tokens is correct?

@ved1beta

Copy link
Copy Markdown
Member Author

Okay!

@NanoCode012 NanoCode012 added the hold don't merge this yet label Dec 25, 2025
@NanoCode012 NanoCode012 merged commit a6080df into axolotl-ai-cloud:main Dec 25, 2025
9 checks passed
winglian added a commit that referenced this pull request Dec 29, 2025
winglian added a commit that referenced this pull request Dec 30, 2025
* upgrade dependencies

* don't use reset sessions

* downgrade transformers, upgrade other deps

* upgrade bnb to 0.49.0

* restore s3 cache

* explicit use local files w hub

* decompress and strip top level dir

* use 2 levels for strip components

* try to preserve permissions for symlinks

* use updated tar

* fix #3293 for distributed

* downgrade bnb

* fast fail after 4

* fix total tokens device

* patch accelerate CP/SP (#3309)

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
@xzuyn

xzuyn commented Jan 5, 2026

Copy link
Copy Markdown
Contributor

tokens/total and tokens/trainable aren't being sent to wandb. It is also being shown as a float instead of an int.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

hold don't merge this yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

train/total_tokens increases from eval and resets when resuming from checkpoint

4 participants